#!/usr/bin/env python
"""
@package ptf

Packet Test Framework (ptf) top level script

To add a new command line option, edit both the config_default dictionary and
the config_setup function. The option's result will end up in the global
ptf.config dictionary.
"""

import sys
import argparse
from subprocess import Popen, PIPE
import logging
import unittest
import time
import os
import importlib
import random
import signal
import fnmatch
import copy
from collections import OrderedDict

# Directory containing this script, with symbolic links resolved.
root_dir = os.path.dirname(os.path.realpath(__file__))

# If a "src/ptf" directory sits next to the script we are running from the
# source tree: put "src" at the front of sys.path so that the ptf imports that
# follow are looked up there first.
pydir = os.path.join(root_dir, "src")
if os.path.exists(os.path.join(pydir, "ptf")):
    # Running from source tree
    sys.path.insert(0, pydir)

import ptf
from ptf import config, __version__
import ptf.ptfutils

##@var DEBUG_LEVELS
# Map from the level names accepted on the command line to the numeric levels
# of the logging module. "verbose" maps to the same level as "debug" and
# "warn" to the same level as "warning".
DEBUG_LEVELS = dict(
    debug=logging.DEBUG,
    verbose=logging.DEBUG,
    info=logging.INFO,
    warning=logging.WARNING,
    warn=logging.WARNING,
    error=logging.ERROR,
    critical=logging.CRITICAL,
)

##@var config_default
# The default configuration dictionary for PTF
#
# config_setup() installs these values as the argparse defaults and then
# builds the run-time configuration by reading every key listed here back
# from the parsed arguments, so each key must be a valid attribute name of
# the argparse namespace.
config_default = {
    # Miscellaneous options
    "list": False,
    "list_test_names": False,
    "allow_user": False,
    # Test selection options
    "test_spec": "",
    "test_file": None,
    "test_dir": None,
    "test_order": "default",
    "test_order_seed": 0xABA,
    "num_shards": 1,
    "shard_id": 0,
    # Switch connection options
    "platform": "eth",
    "platform_args": None,
    "platform_dir": None,
    "interfaces": [],
    "port_info": {},
    "device_sockets": [],  # when using nanomsg
    # Logging options
    "log_file": "ptf.log",
    "log_dir": None,
    "debug": "verbose",  # must be one of the DEBUG_LEVELS keys
    "profile": False,
    "profile_file": "profile.out",
    "xunit": False,
    "xunit_dir": "xunit",
    # Test behavior options
    "relax": False,
    "test_params": None,
    "failfast": False,
    "fail_skipped": False,
    "default_timeout": 2.0,  # seconds
    "default_negative_timeout": 0.1,  # seconds
    "minsize": 0,
    "random_seed": None,  # None: the main script generates and logs a seed
    "disable_ipv6": False,
    "disable_vxlan": False,
    "disable_erspan": False,
    "disable_geneve": False,
    "disable_mpls": False,
    "disable_nvgre": False,
    "disable_igmp": False,
    "disable_rocev2": False,
    "qlen": 100,
    "test_case_timeout": None,
    # Socket options
    "socket_recv_size": 4096,
    # Other configuration
    "port_map": None,  # the main script aborts if this is still None after platform_config_update()
}


def import_module(root_path, module_name):
    """Import a module by name from one specific directory.

    Unlike a plain "import", only root_path is searched, which allows test
    and platform modules to be loaded dynamically from arbitrary directories.

    @param root_path Directory that is searched for the module
    @param module_name Name of the module (file name without the ".py" suffix)
    @return The executed module object; it is also registered in sys.modules
    under module_name
    @raises ImportError if root_path contains no loadable module with that
    name. Exceptions raised while executing the module are propagated.
    """
    # Import the submodules explicitly: "import importlib" alone does not
    # guarantee that they are available as attributes.
    import importlib.machinery
    import importlib.util

    spec = importlib.machinery.PathFinder.find_spec(module_name, [root_path])
    if spec is None or spec.loader is None:
        raise ImportError(
            "No module named %r found in %r" % (module_name, root_path),
            name=module_name,
        )

    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it (as the deprecated
    # Loader.load_module() did) so that it is visible during its own import.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module


def config_setup():
    """
    Set up the configuration including parsing the arguments

    Builds the command line parser (defaults come from config_default),
    parses sys.argv, appends every --pypath entry to sys.path and selects the
    packet manipulation module. Invalid option values are reported through
    parser.error().

    @return A pair (config, args) where config is a dictionary holding one
    entry per config_default key plus "packet_manipulation_module", and args
    is the argparse namespace returned by parse_args()
    """

    usage = "usage: ptf [options] --test-dir TEST_DIR [tests]"

    description = """PTF (Packet Test Framework) is a framework and set of tests
to test a software switch. It is strongly inspired by the OFTest framework, but
it is not tied to OpenFlow. It does not provide any control plane features, but
it is targetted at helping you test a dataplane.

The default configuration assumes that interfaces veth1, veth3, veth5, and veth7
should be connected to the switch's dataplane.

If no positional arguments are given then OFTest will run all tests found in the
--test-dir directory. Otherwise each positional argument is interpreted as
either a test name or a test group name. The union of these will be executed. To
see what groups each test belongs to use the --list option. Tests and groups can
be subtracted from the result by prefixing them with the '^' character.  """

    class ActionInterface(argparse.Action):
        """Parse one --interface value and append (device, port, interface)."""

        @staticmethod
        def _parse(value):
            """Split '[device-]port@interface[;key=val...]' into its parts.

            Returns (device, port, interface, port_cfg); raises ValueError on
            any syntax error.
            """
            port_cfg = {}
            spec, has_params, params = value.partition(";")
            if has_params:
                for elem in params.split(";"):
                    # Exactly one "=" is required, otherwise unpacking fails
                    # with ValueError.
                    key, val = elem.split("=")
                    port_cfg[key.lower()] = val
            dev_and_port, interface = spec.split("@", 1)
            numbers = dev_and_port.split("-")
            if len(numbers) == 1:
                dev, port = 0, int(numbers[0])
            elif len(numbers) == 2:
                dev, port = int(numbers[0]), int(numbers[1])
            else:
                raise ValueError("too many '-' separated fields")
            return dev, port, interface, port_cfg

        def __call__(self, parser, namespace, values, option_string=None):
            try:
                dev, port, interface, port_cfg = self._parse(values)
            except ValueError:
                # parser.error() prints the message and exits.
                parser.error(
                    "incorrect interface syntax (got %s, expected 'port@interface' "
                    "or 'device-port@interface' or providing port configuration "
                    "using 'device-port@interface;arg=val;arg2=val...' )"
                    % repr(values)
                )
            if port_cfg:
                # Port configuration is keyed by port number only.
                namespace.port_info[port] = port_cfg
            getattr(namespace, self.dest).append((dev, port, interface))

    class ActionDeviceSocket(argparse.Action):
        """Parse one --device-socket value and append (device, ports, addr)."""

        @staticmethod
        def _parse_ports(ports):
            """Turn '{1,2,5-8}' into a set of ints; ranges include both ends."""
            port_set = set()
            for item in ports.strip("{}").split(","):
                try:
                    port_set.add(int(item))
                except ValueError:
                    first, last = item.split("-", 1)
                    port_set.update(range(int(first), int(last) + 1))
            return port_set

        @classmethod
        def _parse(cls, value):
            """Split '[device-]{ports}@addr'; raises ValueError if malformed."""
            dev_and_ports, addr = value.split("@", 1)
            # startswith() also copes with an empty device/port part.
            if dev_and_ports.startswith("{"):
                return 0, cls._parse_ports(dev_and_ports), addr
            dev, has_ports, ports = dev_and_ports.partition("-")
            if not has_ports:
                raise ValueError("missing '-' between device and ports")
            return int(dev), cls._parse_ports(ports), addr

        def __call__(self, parser, namespace, values, option_string=None):
            try:
                entry = self._parse(values)
            except ValueError:
                parser.error(
                    "incorrect device-socket syntax (got %s, expected something of the form 0-{1,2,5-8}@<socket addr>)"
                    % repr(values)
                )
            getattr(namespace, self.dest).append(entry)

    class ActionTestDir(argparse.Action):
        """Store the --test-dir value after checking that it is a directory."""

        def __call__(self, parser, namespace, values, option_string=None):
            if not os.path.isdir(values):
                parser.error(
                    "invalid value for --test-dir: directory %s does not exist" % values
                )
            setattr(namespace, self.dest, values)

    parser = argparse.ArgumentParser(usage=usage, description=description)

    # Set up default values. Deep-copy them because the custom actions above
    # mutate the list / dict defaults in place, which must not leak back into
    # config_default.
    parser.set_defaults(**copy.deepcopy(config_default))

    parser.add_argument("--version", action="version", version=__version__)

    parser.add_argument("test_specs", nargs="*", help="Tests / Groups to run")

    parser.add_argument("--list", action="store_true", help="List all tests and exit")
    parser.add_argument(
        "--list-test-names",
        action="store_true",
        help="List test names matching the test spec and exit",
    )
    parser.add_argument(
        "--allow-user",
        action="store_true",
        help="Proceed even if ptf is not run as root",
    )

    parser.add_argument("--pypath", dest="pypath", action="append")

    parser.add_argument(
        "-pmm",
        "--packet-manipulation-module",
        type=str,
        help="Provide packet manipulation module which should be used "
        "as a 'packet' one for other PTF modules",
    )

    group = parser.add_argument_group("Test selection options")
    group.add_argument("-f", "--test-file", help="File of tests to run, one per line")
    group.add_argument(
        "--test-dir",
        type=str,
        action=ActionTestDir,
        required=True,
        help="Directory containing tests",
    )
    test_order_help = """Choose the order in which the tests will be run:
    default (tests are run in the order in which they appear on command line),
    lexico (use default string ordering on test names),
    rand (random order, use --test-order-seed to specify a seed)
    """
    group.add_argument(
        "--test-order", choices=["default", "lexico", "rand"], help=test_order_help
    )
    group.add_argument(
        "--test-order-seed", type=int, help="Specify seed to randomize test order"
    )
    group.add_argument(
        "--num-shards",
        type=int,
        help="Number of shards that can be used to parallelize test execution",
    )
    group.add_argument(
        "--shard-id", type=int, help="Index of shard (>= 0 and < number of shards)"
    )

    group = parser.add_argument_group("Switch connection options")
    group.add_argument("-P", "--platform", help="Platform module name")
    group.add_argument(
        "-a", "--platform-args", help="Custom arguments for the platform"
    )
    group.add_argument(
        "--platform-dir", type=str, help="Directory containing platform modules"
    )
    group.add_argument(
        "--interface",
        "-i",
        type=str,
        dest="interfaces",
        metavar="INTERFACE",
        action=ActionInterface,
        help="Specify a port number and the dataplane interface to use. May be given multiple times. Example: 1@eth1 or 0-1@eth2 (use eth2 as port 1 of device 0)",
    )
    group.add_argument(
        "--device-socket",
        type=str,
        dest="device_sockets",
        metavar="DEVICE-SOCKET",
        action=ActionDeviceSocket,
        help="Specify the nanomsg socket to use to send / receive packets for a given device, as well as the ports to enable on the device. May be given multiple times. Example: 0-{1,2,5-8}@<socket addr>",
    )

    group = parser.add_argument_group("Logging options")
    group.add_argument("--log-file", help="Name of log file")
    group.add_argument("--log-dir", help="Name of log directory")
    # Order the choices by increasing severity (the sort is stable, so aliases
    # keep their DEBUG_LEVELS order).
    dbg_lvl_names = sorted(DEBUG_LEVELS, key=DEBUG_LEVELS.get)
    group.add_argument(
        "--debug",
        choices=dbg_lvl_names,
        help="Debug lvl: debug, info, warning, error, critical",
    )
    group.add_argument(
        "--verbose",
        action="store_const",
        dest="debug",
        const="verbose",
        help="Shortcut for --debug=verbose",
    )
    group.add_argument(
        "-q",
        "--quiet",
        action="store_const",
        dest="debug",
        const="warning",
        help="Shortcut for --debug=warning",
    )
    group.add_argument("--profile", action="store_true", help="Enable Python profiling")
    group.add_argument("--profile-file", help="Output file for Python profiler")
    group.add_argument(
        "--xunit", action="store_true", help="Enable xUnit-formatted results"
    )
    group.add_argument(
        "--xunit-dir", help="Output directory for xUnit-formatted results"
    )

    group = parser.add_argument_group("Test behavior options")
    group.add_argument(
        "--relax",
        action="store_true",
        help="Relax packet match checks allowing other packets",
    )
    group.add_argument(
        "--failfast",
        action="store_true",
        help="Stop running tests as soon as one fails",
    )
    test_params_help = """Set test parameters: [key=val]*;key=val
    """
    group.add_argument("-t", "--test-params", help=test_params_help)
    group.add_argument(
        "--fail-skipped",
        action="store_true",
        help="Return failure if any test was skipped",
    )
    group.add_argument(
        "--default-timeout", type=float, help="Timeout in seconds for most operations"
    )
    group.add_argument(
        "--default-negative-timeout",
        type=float,
        help="Timeout in seconds for negative checks",
    )
    group.add_argument(
        "--minsize", type=int, help="Minimum allowable packet size on the dataplane."
    )
    group.add_argument("--random-seed", type=int, help="Random number generator seed")
    group.add_argument("--disable-ipv6", action="store_true", help="Disable IPv6 tests")
    group.add_argument("--qlen", type=int, help="Default queue length ")
    group.add_argument(
        "--test-case-timeout",
        type=int,
        help="Timeout for each test case, 0 means no timeout",
    )

    group.add_argument(
        "--disable-vxlan",
        action="store_true",
        help="Disable VXLAN (do not import from scapy even if supported)",
    )
    group.add_argument(
        "--disable-geneve",
        action="store_true",
        help="Disable GENEVE (do not import from scapy even if supported)",
    )
    group.add_argument(
        "--disable-erspan",
        action="store_true",
        help="Disable ERSPAN (do not import from scapy even if supported)",
    )
    group.add_argument(
        "--disable-mpls",
        action="store_true",
        help="Disable MPLS (do not import from scapy even if supported)",
    )
    group.add_argument(
        "--disable-nvgre",
        action="store_true",
        help="Disable NVGRE (do not import from scapy even if supported)",
    )
    group.add_argument(
        "--disable-igmp",
        action="store_true",
        help="Disable IGMP (do not import from scapy even if supported)",
    )
    # config_default declares "disable_rocev2"; expose it on the command line
    # like the other disable_* keys.
    group.add_argument(
        "--disable-rocev2",
        action="store_true",
        help="Disable RoCEv2 (do not import from scapy even if supported)",
    )

    group = parser.add_argument_group("Socket options")
    group.add_argument(
        "--socket-recv-size",
        type=int,
        help="When using raw sockets, specify the size of the buffer used to receive packets with socket.recv.",
    )

    # Might need this if other parsers want command line
    # parser.allow_interspersed_args = False
    args = parser.parse_args()
    if args.pypath:
        sys.path.extend(args.pypath)

    # Convert args from a Namespace to a plain dictionary
    config = {key: getattr(args, key) for key in config_default}

    # For selecting the packet manipulation module when running the
    # `ptf` command, the order of precedence is:
    # (1) If the `--packet-manipulation-module` command line option is
    #     present, use its value.
    # (2) Otherwise, if the environment variable
    #     PTF_PACKET_MANIPULATION_MODULE is defined, use its value.
    # (3) Otherwise, use "ptf.packet_scapy"
    pmm_key = "packet_manipulation_module"
    config[pmm_key] = (
        getattr(args, pmm_key)
        or os.getenv("PTF_PACKET_MANIPULATION_MODULE")
        or "ptf.packet_scapy"
    )

    return (config, args)


def logging_setup(config):
    """
    Set up logging based on config

    Sets the root logger level from config["debug"], then clears previous
    log output: if config["log_dir"] is set, that directory is deleted (when
    it exists) and recreated empty; otherwise an existing config["log_file"]
    is removed. Finally opens the "main" log file through ptf.open_logfile.

    @param config The ptf configuration dictionary
    """

    root_logger = logging.getLogger()
    root_logger.setLevel(DEBUG_LEVELS[config["debug"]])

    log_dir = config["log_dir"]
    if log_dir is None:
        # Single log file mode: start from a fresh file
        log_file = config["log_file"]
        if os.path.exists(log_file):
            os.remove(log_file)
    else:
        # Log directory mode: start from an empty directory
        if os.path.exists(log_dir):
            import shutil

            shutil.rmtree(log_dir)
        os.makedirs(log_dir)

    ptf.open_logfile("main")


def xunit_setup(config):
    """
    Set up xUnit output based on config

    Does nothing unless config["xunit"] is set. Otherwise the directory
    config["xunit_dir"] is deleted if it exists and then created again, so
    that it is empty.

    @param config The ptf configuration dictionary
    """

    if config["xunit"]:
        import shutil

        out_dir = config["xunit_dir"]
        # Drop results left over from a previous run
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)


def pcap_setup(config):
    """
    Set up dataplane packet capturing based on config

    When no log directory is configured, a capture is started on
    ptf.dataplane_instance, written to the log file name with its extension
    replaced by ".pcap". With a log directory nothing is done here.

    @param config The ptf configuration dictionary
    """

    if config["log_dir"] is not None:
        # start_pcap is called per-test in base_tests
        return

    base_name, _ext = os.path.splitext(config["log_file"])
    ptf.dataplane_instance.start_pcap(base_name + ".pcap")


def profiler_setup(config):
    """
    Set up profiler based on config

    @param config The ptf configuration dictionary
    @return An enabled cProfile.Profile object if config["profile"] is set,
    None otherwise
    """

    if config["profile"]:
        import cProfile

        active_profiler = cProfile.Profile()
        active_profiler.enable()
        return active_profiler

    return None


def profiler_teardown(profiler):
    """
    Tear down profiler based on config

    Note that the configuration is read from the module-level config
    dictionary, not from a parameter.

    @param profiler The object returned by profiler_setup; only used when
    config["profile"] is set
    """

    if config["profile"]:
        profiler.disable()
        # Write the collected statistics to the configured output file
        profiler.dump_stats(config["profile_file"])


def load_test_modules(config):
    """
    Discover and import the test modules below the test_dir directory.

    Every "*.py" file whose name does not start with a dot is treated as a
    module. A module that is already present in sys.modules under the same
    name is reused; otherwise it is imported from its directory. Test cases
    are the classes of a module that subclass unittest.TestCase and have a
    "runTest" attribute.

    As a side effect, each directory containing python files is appended to
    sys.path, missing annotation attributes (_groups, _nonstandard,
    _disabled, _testtimeout) are set on the test classes, and _groups is
    extended with the module name, "standard" / "all", or "disabled" as
    appropriate.

    @param config The ptf configuration dictionary
    @returns An ordered dictionary from test module names to tuples of
    (module, dictionary from test names to test classes). Modules without
    test cases are left out.
    """

    loaded = OrderedDict()

    for dirpath, subdirs, files in os.walk(config["test_dir"]):
        # Sorting (the sub-directory list in place, so os.walk follows it)
        # makes the visiting order the same on every run
        module_files = sorted(fnmatch.filter(files, "[!.]*.py"))
        subdirs.sort()

        if not module_files:
            continue

        # Allow tests to import each other
        sys.path.append(dirpath)

        for module_file in module_files:
            modname = os.path.splitext(os.path.basename(module_file))[0]

            try:
                mod = (
                    sys.modules[modname]
                    if modname in sys.modules
                    else import_module(dirpath, modname)
                )
            except:
                # Log which file is at fault, then let the error propagate
                logging.warning("Could not import file " + module_file)
                raise

            # Collect all testcases defined in (or imported into) the module
            tests = {
                name: obj
                for name, obj in mod.__dict__.items()
                if type(obj) == type
                and issubclass(obj, unittest.TestCase)
                and hasattr(obj, "runTest")
            }
            if not tests:
                continue

            for test_class in tests.values():
                # Fill in annotations the test did not provide
                if not hasattr(test_class, "_groups"):
                    test_class._groups = []
                if not hasattr(test_class, "_nonstandard"):
                    test_class._nonstandard = False
                if not hasattr(test_class, "_disabled"):
                    test_class._disabled = False
                if not hasattr(test_class, "_testtimeout"):
                    test_class._testtimeout = None

                if test_class._disabled:
                    # Disabled tests only join the "disabled" group, so that
                    # users can conveniently exclude them when including
                    # only groups. Eg. -s "group1 ^disabled"
                    test_class._groups.append("disabled")
                    continue

                # Enabled tests belong to their module's group ...
                test_class._groups.append(modname)
                # ... and, unless marked nonstandard, to the standard group
                if not test_class._nonstandard:
                    test_class._groups.append("standard")
                    test_class._groups.append("all")  # backwards compatibility

            loaded[modname] = (mod, tests)

    return loaded


def prune_tests(test_specs, test_modules):
    """
    Return tests matching the given test-specs.

    The specs are applied in order. A plain spec adds every test that belongs
    to the named group or whose "module.test" name equals the spec. A spec
    prefixed with '^' removes the matching tests that were selected so far.
    Modules are only present in the result while they hold at least one
    selected test.

    Terminates the program through die() when a plain (non-negated) spec
    matches no test at all.

    @param test_specs A list of group names or test names.
    @param test_modules Same format as the output of load_test_modules.
    @returns Same format as the output of load_test_modules.
    """
    result = OrderedDict()
    for spec in test_specs:
        negated = spec.startswith("^")
        name = spec[1:] if negated else spec
        matched = False

        for modname, (mod, tests) in test_modules.items():
            for testname, test in tests.items():
                if name not in test._groups and name != "%s.%s" % (modname, testname):
                    continue
                matched = True

                if not negated:
                    selected = result.setdefault(modname, (mod, OrderedDict()))
                    selected[1][testname] = test
                    continue

                # Negated spec: only ever remove. Creating an entry here would
                # leave empty modules in the result and change module order.
                selected = result.get(modname)
                if selected is not None and testname in selected[1]:
                    del selected[1][testname]
                    if not selected[1]:
                        del result[modname]

        if not matched and not negated:
            die("test-spec element %s did not match any tests" % name)

    return result


def die(msg, exit_val=1):
    """
    Log a critical message and terminate by raising SystemExit

    @param msg Message logged at CRITICAL level
    @param exit_val Exit status carried by the SystemExit exception
    """
    logging.critical(msg)
    raise SystemExit(exit_val)


def _space_to(n, str):
    """
    Generate a string of spaces to achieve width n given string str
    If length of str >= n, return one space
    """
    spaces = n - len(str)
    if spaces > 0:
        return " " * spaces
    return " "


def test_params_parse(config):
    test_params = config["test_params"]
    if test_params is None:
        logging.debug("No test params were provided with '--test-params' / '-t'")
        return None
    params_str = "class _TestParams:\n    " + test_params
    namespace = {}
    try:
        exec(params_str, namespace)
    except:
        logging.error(
            "Error when parsing test params "
            "(provided with '--test-params' / '-t'). "
            "Make sure you used the correct syntax: "
            '--test-params="[k=v;]*k=v"'
        )
        return None
    params = {}
    logging.debug("Parsed test parameters:")
    for k, v in list(vars(namespace["_TestParams"]).items()):
        if k[:2] != "__":
            params[k] = v
            logging.debug("\t*{}={}".format(k, v))
    logging.debug(
        "If something is missing, make sure you used the correct syntax: "
        '--test-params="[k=v;]*k=v"'
    )
    return params


#
# Main script
#

# Parse the command line and publish the result as the global configuration
# (the module-level name "config" is the ptf.config dictionary updated here)
new_config, args = config_setup()
ptf.config.update(new_config)

logging_setup(config)
xunit_setup(config)
logging.info("++++++++ " + time.asctime() + " ++++++++")

# import after logging is configured so that scapy error logs (from importing
# packet.py) are silenced and our own warnings are logged properly.
import ptf.testutils

# Parse the test params and log them.
# This happens before importing the test modules in case test parameters are
# being accessed at test import time.
ptf.testutils.TEST_PARAMS = test_params_parse(config)
# Initialize port information
ptf.testutils.PORT_INFO = config["port_info"]

# Test specs come from the command line, followed by the lines of the
# optional test file
test_specs = args.test_specs
if config["test_file"] is not None:
    with open(config["test_file"], "r") as f:
        for raw_line in f:
            # Drop everything after a '#' (comment) and surrounding blanks
            spec = raw_line.partition("#")[0].strip()
            if spec:
                test_specs.append(spec)
if not test_specs:
    # Nothing requested explicitly: run the "standard" group
    test_specs = ["standard"]

test_modules = load_test_modules(config)

# Check if test list is requested; display and exit if so
if config["list"]:
    mod_count = 0
    test_count = 0
    all_groups = set()
    print(
        """\
Tests are shown grouped by module. If a test is in any groups beyond "standard"
and its module's group then they are shown in parentheses."""
    )
    print()
    print(
        """\
Tests marked with '!' are disabled because they are experimental, special-purpose,
or are too long to be run normally. These are not part of the "standard" test
group or their module's test group."""
    )
    print()
    print("Test List:")
    for modname, (mod, tests) in test_modules.items():
        mod_count += 1
        # Only the first line of a docstring is shown
        desc = (mod.__doc__ or "No description").strip().split("\n")[0]
        start_str = "  Module " + mod.__name__ + ": "
        print(start_str + _space_to(22, start_str) + desc)
        for testname, test in tests.items():
            try:
                desc = (test.__doc__ or "").strip().split("\n")[0]
            except Exception:
                # e.g. a class that sets __doc__ to something that is not a string
                desc = "No description"
            groups = set(test._groups) - {"all", "standard", modname}
            all_groups.update(test._groups)
            if groups:
                # Sorted so that the listing is identical from run to run
                # (set iteration order is not stable across interpreter runs)
                desc = "(%s) %s" % (",".join(sorted(groups)), desc)
            if hasattr(test, "_versions"):
                desc = "(%s) %s" % (",".join(sorted(test._versions)), desc)
            # '*' flags nonstandard tests, '!' flags disabled tests
            start_str = " %s%s %s:" % (
                "*" if test._nonstandard else " ",
                "!" if test._disabled else " ",
                testname,
            )
            if len(start_str) > 22:
                # Name too long for the column: description goes on its own line
                desc = "\n" + _space_to(22, "") + desc
            print(start_str + _space_to(22, start_str) + desc)
            test_count += 1
        print()
    print("%d modules shown with a total of %d tests" % (mod_count, test_count))
    print()
    print("Test groups: %s" % (", ".join(sorted(all_groups))))

    sys.exit(0)

# Keep only the tests selected by the test specs
test_modules = prune_tests(test_specs, test_modules)

# Check if the list of selected test names is requested; display and exit if so
if config["list_test_names"]:
    for modname, (mod, tests) in test_modules.items():
        for testname, test in tests.items():
            print("%s.%s" % (modname, testname))
    sys.exit(0)

# Generate the test suite: one instance of every selected test class
test_suite = []
for modname, (mod, tests) in test_modules.items():
    for testname, test in tests.items():
        test_suite.append(test())

# Keep only this shard's share of the tests: every num_shards-th test,
# starting at index shard_id
if config["shard_id"] < 0 or config["shard_id"] >= config["num_shards"]:
    die("shard id should be equal or greater than 0 and lower than number of shards")
test_suite = test_suite[config["shard_id"] :: config["num_shards"]]

if config["test_order"] == "lexico":
    # TestCase instances define no ordering of their own, so a plain sort()
    # raises TypeError; order by the test id string instead.
    test_suite.sort(key=lambda t: t.id())
elif config["test_order"] == "rand":
    seed = config["test_order_seed"]
    random.seed(seed)
    random.shuffle(test_suite)

# Default to the platform modules shipped with the ptf package
if config["platform_dir"] is None:
    from ptf import platforms

    config["platform_dir"] = os.path.dirname(os.path.abspath(platforms.__file__))

# Allow platforms to import each other
sys.path.append(config["platform_dir"])

# Load the platform module
platform_name = config["platform"]
logging.info("Importing platform: " + platform_name)

# TODO(antonin): put this check in platforms/nn.py ?
if platform_name == "nn":
    try:
        import nnpy
    except Exception:
        # Any failure to import nnpy makes the platform unusable; interrupts
        # and interpreter exits are deliberately not intercepted.
        die("Cannot use 'nn' platform if nnpy package is not installed")

platform_mod = None
try:
    platform_mod = import_module(config["platform_dir"], platform_name)
except:
    # Log which platform failed, then let the original error propagate
    logging.warning("Failed to import " + platform_name + " platform module")
    raise

# The platform module is expected to fill in config["port_map"]
try:
    platform_mod.platform_config_update(config)
except:
    logging.warning("Could not run platform host configuration")
    raise

if config["port_map"] is None:
    die("Interface port map was not defined by the platform. Exiting.")

logging.debug("Configuration: " + str(config))
logging.info("port map: " + str(config["port_map"]))

# Propagate the configured values to the utility modules
ptf.ptfutils.default_timeout = config["default_timeout"]
ptf.ptfutils.default_negative_timeout = config["default_negative_timeout"]
ptf.testutils.MINSIZE = config["minsize"]

if os.getuid() != 0 and not config["allow_user"] and platform_name != "nn":
    die("Super-user privileges required. Please re-run with sudo or as root.")

if config["random_seed"] is not None:
    logging.info("Random seed: %d" % config["random_seed"])
    random.seed(config["random_seed"])
else:
    # Generate random seed and report to log file
    seed = random.randrange(100000000)
    logging.info("Autogen random seed: %d" % seed)
    random.seed(seed)

# Remove python's signal handler which raises KeyboardInterrupt. Exiting from
# an exception waits for all threads to terminate which might not happen.
signal.signal(signal.SIGINT, signal.SIG_DFL)

# Run the selected tests one by one, then report and set the exit status
if __name__ == "__main__":
    profiler = profiler_setup(config)

    # Set up the dataplane
    ptf.dataplane_instance = ptf.dataplane.DataPlane(config)
    pcap_setup(config)
    # port_map keys are (device, port) pairs, values are interface names
    for port_id, ifname in config["port_map"].items():
        device, port = port_id
        ptf.dataplane_instance.port_add(ifname, device, port)

    logging.info("*** TEST RUN START: " + time.asctime())
    if config["xunit"]:
        try:
            import xmlrunner  # fail-fast if module missing
        except ImportError:
            # Release what was set up above before propagating the error
            ptf.dataplane_instance.kill()
            profiler_teardown(profiler)
            raise
        runner = xmlrunner.XMLTestRunner(
            output=config["xunit_dir"], outsuffix="", verbosity=2
        )
    else:
        runner = unittest.TextTestRunner(verbosity=2)
    test_case_timeout = config["test_case_timeout"]
    run_failures = []
    run_errors = []
    run_timeouts = []
    from ptf.ptfutils import Timeout

    for t in test_suite:
        # A per-test timeout annotation takes precedence over the global one
        if t._testtimeout is not None:
            this_test_timeout = t._testtimeout
        else:
            this_test_timeout = test_case_timeout

        # A timeout of None or 0 means "no timeout"
        if this_test_timeout:
            with Timeout(this_test_timeout):
                result = runner.run(t)
        else:
            result = runner.run(t)

        run_failures.extend(result.failures)
        run_errors.extend(result.errors)

        if this_test_timeout:
            # Each entry of result.errors is a (test, traceback string) pair
            for case in result.errors:
                traceback_str = case[1]
                # TODO: hacky? could not think of a better way
                if "raise Timeout.TimeoutError()" in traceback_str:
                    logging.info("Test case failed because of timeout")
                    run_timeouts.append(case)

        if config["failfast"] and not result.wasSuccessful():
            logging.info("Test failed and failfast mode enabled, stopping")
            break

    ptf.open_logfile("main")
    if ptf.testutils.skipped_test_count > 0:
        ts = " tests"
        if ptf.testutils.skipped_test_count == 1:
            ts = " test"
        logging.info("Skipped " + str(ptf.testutils.skipped_test_count) + ts)
        print("Skipped " + str(ptf.testutils.skipped_test_count) + ts)
    logging.info("*** TEST RUN END  : " + time.asctime())

    # Shutdown the dataplane
    ptf.dataplane_instance.stop_pcap()  # no-op if pcap not started
    ptf.dataplane_instance.kill()
    ptf.dataplane_instance = None

    profiler_teardown(profiler)

    if run_failures or run_errors:
        print()
        print("******************************************")
        print("ATTENTION: SOME TESTS DID NOT PASS!!!")
        # The per-test summaries are only printed when xUnit output is off
        if (not config["xunit"]) and run_failures:
            print()
            print("The following tests failed:")
            print(", ".join([f[0].__class__.__name__ for f in run_failures]))
        if (not config["xunit"]) and run_errors:
            print()
            print("The following tests errored:")
            print(", ".join([f[0].__class__.__name__ for f in run_errors]))
        if (not config["xunit"]) and run_timeouts:
            print()
            print("The following tests errored because of a timeout:")
            print(", ".join([f[0].__class__.__name__ for f in run_timeouts]))
        print()
        print("******************************************")
        # exit(1) hangs sometimes
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(1)
    if ptf.testutils.skipped_test_count > 0 and config["fail_skipped"]:
        print()
        print("******************************************")
        # Format the count into the message; passing it as a second print()
        # argument would output a literal "%d".
        print("ATTENTION: %d TESTS WERE SKIPPED!!!" % ptf.testutils.skipped_test_count)
        print("******************************************")
        print()
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(1)
